import numpy as np
import matplotlib.pyplot as plt
import cv2
import sys


def poly_cut(img, pts):
    '''Cut a single plot out of a full field image.

    Parameters
    ----------
    img : array
        The full field image, indexed as img[row, column(, channel)].
    pts : array-like
        Vertices (x, y) of the polygon bounding the plot, in pixel coordinates of img.
        Any shape that reshapes to (N, 2) is accepted (e.g. (N, 2) or OpenCV-style (N, 1, 2)).

    Returns
    -------
    res : array
        The bounding-rectangle crop of img with every pixel outside the polygon set to 0.
    x0, y0 : int
        Column and row of the upper left corner of the crop within img.
    '''
    # reshape returns a new array, so it must be assigned (the result was previously discarded).
    points = np.array(pts, dtype=np.int32).reshape((-1, 2))
    # Vertices with negative coordinates (slightly outside the image) are clamped to the border.
    points[points < 0] = 0
    x0, y0, width, height = cv2.boundingRect(points)
    cropped = img[y0: y0 + height, x0: x0 + width]
    # Shift the polygon into the crop's coordinate system (kept as int32 for cv2.fillPoly).
    points = points - np.array([x0, y0], dtype=np.int32)
    # The mask is sized from the actual crop, which may be smaller than the bounding
    # rectangle when the polygon extends past the right/bottom image edge.
    mask = np.zeros(cropped.shape[:2], dtype=np.uint8)
    cv2.fillPoly(mask, np.int32([points]), (1))  # fills mask in place: 1 inside the polygon
    res = cv2.bitwise_and(cropped, cropped, mask=mask)
    return res, x0, y0


def extractHueSat(imRGB,list_bds):
    '''Black out every pixel of an RGB image whose hue or saturation is out of range.

    list_bds is [hl, hh, sl, sh]: the lower/upper hue bounds followed by the lower/upper
    saturation bounds, expressed in the units produced by
    cv2.cvtColor(..., cv2.COLOR_RGB2HSV).
    A pixel is kept only when hl < hue < hh and sl < saturation < sh (bounds excluded);
    every other pixel is set to (0, 0, 0). The input image is not modified: a filtered
    copy is returned.
    '''
    hue_lo, hue_hi, sat_lo, sat_hi = list_bds
    hsv = cv2.cvtColor(imRGB, cv2.COLOR_RGB2HSV)
    hue = hsv[:, :, 0]
    sat = hsv[:, :, 1]
    # A pixel is rejected if it fails either the hue test or the saturation test.
    reject = (
        (hue <= hue_lo) | (hue >= hue_hi)
        | (sat <= sat_lo) | (sat >= sat_hi)
    )
    filtered = np.copy(imRGB)
    filtered[reject] = (0, 0, 0)
    return filtered


def count_soil_canopy_pixels(binary_imageRGB):
    '''Count the soil and canopy pixels of a segmented image.

    Parameters
    ----------
    binary_imageRGB : array-like
        Either a 3-D colour image in which non-canopy pixels have been set to (0, 0, 0)
        (e.g. the output of extractHueSat), or an already 2-D image (e.g. the output of
        binarize).

    Returns
    -------
    (soil_pixels, canopy_pixels) : tuple of int
        canopy_pixels is the number of pixels whose channel mean is non-zero;
        soil_pixels is the number of remaining (zero) pixels.

    Raises
    ------
    ValueError
        If the input cannot be reduced to a 2-D image.
    '''
    image = np.asarray(binary_imageRGB)
    # Collapse the channel axis so each pixel becomes a single value.
    image_2d = np.mean(image, axis=2) if image.ndim > 2 else image
    if image_2d.ndim != 2:
        raise ValueError(f"expected a 2-D or 3-D image, got shape {image.shape}")

    canopy_pixels = int(np.count_nonzero(image_2d))
    # Every pixel of a 2-D array is either non-zero (canopy) or zero (soil), so the
    # soil count follows directly; no separate consistency check is needed.
    soil_pixels = image_2d.size - canopy_pixels
    return soil_pixels, canopy_pixels


def binarize(array):
    '''Collapse a colour image to a single channel by averaging its colour channels.

    Despite the name, no threshold is applied: the result is the per-pixel mean over
    axis 2 (floating point). Pixels that were (0, 0, 0) become 0, so for non-negative
    images the zero/non-zero pattern of the result separates background from foreground.
    An input that is already 2-D is returned unchanged (as an array), so the function
    can safely be applied to images without a channel axis.
    '''
    image = np.asarray(array)
    if image.ndim == 2:
        return image
    # Mean over the channel axis of a rows x columns x channels image.
    return np.mean(image, axis=2)


def find_px_to_cm2_coeff(polys, field, inter_ridge_cm=None):
    '''Return the area in cm^2 covered by one image pixel.

    Multiplying an area measured in pixels by the returned factor gives cm^2.
    ATTENTION: This conversion uses project specific measurements for the
    inter-ridge distance in the field.

    Parameters
    ----------
    polys : dict
        Plot number -> polygon vertices (x, y) in image pixels. The segments between
        vertices 1..4 and 7..10 are taken as inter-ridge distances (this relies on the
        project-specific vertex layout of the plot polygons).
    field : str
        Field code used to look up the true inter-ridge distance ('M', 'V' or 'S').
    inter_ridge_cm : float, optional
        True inter-ridge distance in cm; when given it overrides the per-field lookup.

    Returns
    -------
    float
        Area of a single pixel in cm^2.

    Raises
    ------
    KeyError
        If field is unknown and inter_ridge_cm is not given.
    ValueError
        If polys is empty.
    '''
    # cm_meas is the known true distance as measured in cm on the planting device
    cm_meas = {"M": 74, 'V': 75, 'S': 75}
    ridge_cm = cm_meas[field] if inter_ridge_cm is None else inter_ridge_cm

    if not polys:
        raise ValueError("polys is empty: cannot estimate the pixel scale")

    # Average inter-ridge distance (in pixels) along both long sides of every polygon.
    first_means = []
    second_means = []
    for vertices in polys.values():
        # float conversion avoids wrap-around if the vertices use an unsigned dtype
        vertices = np.asarray(vertices, dtype=float)
        first_means.append(np.mean(np.linalg.norm(vertices[1:4] - vertices[2:5], axis=1)))
        second_means.append(np.mean(np.linalg.norm(vertices[7:10] - vertices[8:11], axis=1)))

    # pixel_measurement is the field-average inter-ridge distance in pixels
    pixel_measurement = np.mean((np.mean(first_means), np.mean(second_means)))
    # pixel_measurement px correspond to ridge_cm cm, hence px per cm = px / cm.
    # (The previous version computed cm / px here, which inverted the final factor.)
    pixels_for_1cm = pixel_measurement / ridge_cm
    cm_in_1px = 1 / pixels_for_1cm
    area_1_pixel = cm_in_1px ** 2
    return area_1_pixel

# +++++++++++++++++++++++++++++++++++++ INPUT the desired parameters +++++++++++++++++++++++++++++++++++++++++++++++
# Set parameters: field, year, and hue_saturation_ranges as defined in the protocol.
# NOTE: field must be one of the codes known to find_px_to_cm2_coeff ('M', 'V', 'S').
field = 'M'
year = '2020'
# [hue low, hue high, saturation low, saturation high], consumed by extractHueSat
hue_saturation_ranges = [40, 68, 86, 162]
# Dataset identifier used in the file names, e.g. 'M_2020'
chosen_field = '_'.join((field, year))
# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

# Load the field image (orthophoto) for the chosen year and field.
image_path = f'./Images/{chosen_field}.npy'
field_image = np.load(image_path)

# Load the plot boundaries relative to the image: a dictionary stored inside a
# NumPy array, hence allow_pickle=True and .item() to unwrap it.
# NOTE(review): allow_pickle=True can execute arbitrary code from the file;
# only load polygon files from a trusted source.
polygon_path = f'./Polygons/{chosen_field}_polygons.npy'
polygons = np.load(polygon_path, allow_pickle=True).item()

# Choose a specific plot and extract its cropped image.
all_plot_numbers = range(1, 721)
plot_number = 599
boundary_of_plot = polygons[plot_number]
# r0, r1: column and row of the crop's upper left corner in the field image
cropped_plot_image, r0, r1 = poly_cut(field_image, boundary_of_plot)

# Segment the canopy: every pixel outside the hue/saturation window is set to black.
blacksoil_imageRGB = extractHueSat(cropped_plot_image, hue_saturation_ranges)

# Average the colour channels into a 2-D image (0 = soil/background, non-zero = canopy)
# and count the pixels of each class from that same 2-D image.
# NOTE(review): pixels outside the plot polygon are zeroed by poly_cut, so they are
# counted as soil here; the canopy count is unaffected.
binaryimage = binarize(blacksoil_imageRGB)
soil, canopy = count_soil_canopy_pixels(binaryimage)

# Convert the measurement unit from pixels to cm^2.
pixel_conversion_factor = find_px_to_cm2_coeff(polygons, field)
# Number of plants per plot (project specific); dividing the plot canopy area by it
# gives the canopy area of the average plant in the plot.
PLANTS_PER_PLOT = 24
avg_canopy_area = (canopy / PLANTS_PER_PLOT) * pixel_conversion_factor


# Visualize the process: orthophoto, plot crop, and canopy segmentation in one figure.
# NOTE(review): plt.ion() without a later plt.show() suits an interactive session
# (IPython/Spyder); when run as a plain script the window may close immediately.
plt.ion()
plt.figure(1)
plt.clf()
fig, ax = plt.subplots(nrows=3, ncols=1, num=1)
overview_ax, crop_ax, segmentation_ax = ax

# Panel 1: every plot outline in grey, the selected plot drawn on top in red.
overview_ax.imshow(field_image)
overview_ax.set_axis_off()
for plot_num, polygon in polygons.items():
    rotarr = np.array(polygon)
    xs, ys = rotarr[:, 0], rotarr[:, 1]
    if plot_num == plot_number:
        overview_ax.plot(xs, ys, color='tab:red', zorder=800)
    else:
        overview_ax.plot(xs, ys, color='silver', alpha=0.5)
overview_ax.set_title('orthophoto with plot boundaries')

# Panel 2: the cropped plot.
crop_ax.imshow(cropped_plot_image)
crop_ax.set_axis_off()
crop_ax.set_title(f'plot No. {plot_number}')

# Panel 3: the channel-averaged segmentation with the pixel counts and area estimate.
segmentation_ax.imshow(binaryimage)
segmentation_ax.set_xticks([])
segmentation_ax.set_yticks([])
segmentation_ax.set_title(f'canopy segmentation in plot No. {plot_number}')
summary_label = (
    f'soil = {soil} px., canopy = {canopy} px., '
    f'average plant canopy area = {int(avg_canopy_area)}'
    r' cm$^{2}$'
)
segmentation_ax.set_xlabel(summary_label)
